import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Parameter
import json
import sys
import pickle as pkl
import networkx as nx
import scipy.sparse as sp
import random
import time
import os
import errno

def mkdir_if_missing(directory):
    """Create ``directory`` (including missing parents) unless the path already exists.

    If another process creates the directory between the existence check and
    ``os.makedirs``, the resulting ``EEXIST`` error is ignored; any other
    ``OSError`` propagates to the caller.
    """
    if os.path.exists(directory):
        return
    try:
        os.makedirs(directory)
    except OSError as err:
        if err.errno == errno.EEXIST:
            return
        raise

class classifier_lstm(nn.Module):
    """Two-stage LSTM classifier.

    The object embedding is run through ``lstm_oa``. Its final (hidden, cell)
    state initialises ``lstm_a_out``, which reads the atom embedding. The last
    hidden state of ``lstm_a_out`` is classified by a linear layer. Both LSTMs
    see a sequence of length 1.

    Args:
        n_cls: number of output classes.
        d_atom: hidden size of both LSTMs and input size of ``lstm_a_out``, so
            it must equal the per-sample width of ``atoms``.
        d_obj: per-sample width of ``objs``.
        d_cls: accepted for interface compatibility; unused.
    """
    def __init__(self, n_cls, d_atom, d_obj, d_cls = 0):
        super(classifier_lstm, self).__init__()
        self.lstm_oa = nn.LSTM(d_obj, d_atom)
        self.lstm_a_out = nn.LSTM(d_atom, d_atom)
        self.classifier = nn.Linear(d_atom, n_cls)

    def forward(self, atoms, objs):
        """Return class logits of shape ``[batch_size, n_cls]``.

        ``atoms`` is ``[batch_size, d_atom]`` (the original note describes it
        as ``(lr+la) * dim`` wide). ``objs`` is ``[batch_size, d_obj]``.
        """
        batch_size = atoms.shape[0]
        # nn.LSTM expects (seq_len, batch, features); use a length-1 sequence.
        objs = objs.view(1, batch_size, -1)
        atoms = atoms.view(1, batch_size, -1)
        _, obj_state = self.lstm_oa(objs)
        _, (hid_atom, _) = self.lstm_a_out(atoms, obj_state)
        predictions = self.classifier(hid_atom)  # [1, batch_size, n_cls]
        # Only drop the sequence/layer axis; a bare squeeze() would also drop
        # the batch axis when batch_size == 1.
        return predictions.squeeze(0)

def cos_dis_loss(a, b):
    """Sum ``1 - <a_i, b_j>`` over every pair of rows ``a_i`` of ``a`` and ``b_j`` of ``b``.

    Both inputs are viewed as 2-D ``[-1, dim_vector]`` matrices (``dim_vector``
    is the size of the last axis), so the result covers all row pairs, not only
    matching ones. No normalisation is applied, so the dot products are cosine
    similarities only if the rows are already unit length.

    Raises:
        ValueError: if the last dimensions of ``a`` and ``b`` differ.
    """
    width = a.shape[-1]
    if width != b.shape[-1]:
        raise ValueError('given vectors with different lengths')
    rows_a = a.view(-1, width)
    rows_b = b.view(-1, width)
    similarity = rows_a.mm(rows_b.t())
    return (1 - similarity).sum()

def lil_sample(lil, n_per_list, flatten = None):
    """Sample with replacement from each inner list of a list of lists of lists.

    For every element ``sublists`` of ``lil``, ``n_per_list[pos]`` items are
    drawn with ``random.choices`` from ``sublists[pos]``. The per-position
    samples are kept as separate lists, or concatenated into one list when
    ``flatten`` is truthy.
    """
    sampled = []
    for sublists in lil:
        picks = [random.choices(pool, k=n_per_list[pos]) for pos, pool in enumerate(sublists)]
        if flatten:
            picks = [item for group in picks for item in group]
        sampled.append(picks)
    return sampled

def multi_nb_sample(c_ids, edge_lists, n_per_hop, flatten = None):
    """Randomly sample a fixed number of multi-hop neighbours for each node in ``c_ids``.

    Args:
        c_ids: node ids (as used in ``edge_lists``) to sample neighbours for.
        edge_lists: ``{node: [neighbour, ...]}`` dict, turned into an undirected
            networkx graph.
        n_per_hop: ``n_per_hop[h]`` nodes are drawn with replacement
            (``random.choices``) among the nodes joined to the query node by a
            walk of exactly ``h + 1`` steps. For even ``h + 1`` this can
            include the node itself.
        flatten: if truthy, concatenate a node's per-hop samples into one list.

    Returns:
        One entry per id in ``c_ids``. Each entry is a list of per-hop lists,
        or one flat list when ``flatten`` is truthy. Sampled neighbours are
        elements of a NumPy array of node ids.

    Raises:
        IndexError: if an id of ``c_ids`` is not a graph node, or a hop has no
            reachable nodes while ``n_per_hop[h] > 0``.
    """
    n_hop = len(n_per_hop)
    graph = nx.from_dict_of_lists(edge_lists)
    # The adjacency matrix rows/columns follow the graph's node order (dict keys
    # first, then neighbour-only ids in first-seen order), so that order *is* the
    # position -> id map. Reading it from the graph avoids an O(n^2) rebuild.
    node_order = list(graph)
    id_map = np.array(node_order)
    adj = nx.adjacency_matrix(graph, nodelist=node_order)

    # adj_multi[h] marks pairs joined by a walk of exactly h + 1 steps.
    adj_m = adj
    adj_multi = []
    for h in range(n_hop):
        adj_multi.append((adj_m > 0) * 1)
        if h + 1 < n_hop:  # the next power is only needed for another hop
            adj_m = adj_m.dot(adj)

    adj_list_multi = []  # [batch_size, n_hop, *], or [batch_size, *] when flattened
    for node_id in c_ids:
        row = np.where(id_map == node_id)[0][0]
        nbs_multi = []
        for h in range(n_hop):
            picked = random.choices(adj_multi[h][row].indices, k=n_per_hop[h])
            nbs = [id_map[i] for i in picked]
            if flatten:
                nbs_multi += nbs
            else:
                nbs_multi.append(nbs)
        adj_list_multi.append(nbs_multi)
    return adj_list_multi

def multi_nb(edge_lists, n_hop, flatten = None):
    """List all multi-hop neighbours of every node in ``edge_lists``.

    Args:
        edge_lists: ``{node: [neighbour, ...]}`` dict, turned into an undirected
            networkx graph.
        n_hop: number of hops. Hop ``h`` (0-based) contains the nodes joined to
            the node by a walk of exactly ``h + 1`` steps.
        flatten: if truthy, concatenate a node's per-hop lists into one list.

    Returns:
        One entry per key of ``edge_lists``, in key order. Each entry is a list
        of per-hop neighbour-id lists, or one flat list when ``flatten`` is
        truthy.
    """
    graph = nx.from_dict_of_lists(edge_lists)
    # Matrix rows/columns follow this node order. Translate ids <-> positions
    # through it instead of assuming a node id equals its row position (which
    # only holds when the keys are exactly 0..n-1 in insertion order).
    node_order = list(graph)
    position_of = {node: pos for pos, node in enumerate(node_order)}
    adj = nx.adjacency_matrix(graph, nodelist=node_order)

    # adj_multi[h] marks pairs joined by a walk of exactly h + 1 steps.
    adj_m = adj
    adj_multi = []
    for h in range(n_hop):
        adj_multi.append((adj_m > 0) * 1)
        if h + 1 < n_hop:  # the next power is only needed for another hop
            adj_m = adj_m.dot(adj)

    adj_list_multi = []  # [n_nodes, n_hop, *], or [n_nodes, *] when flattened
    for node in edge_lists:
        row = position_of[node]
        nbs_multi = []
        for hop_adj in adj_multi:
            nbs = [node_order[col] for col in hop_adj[row].indices.tolist()]
            if flatten:
                nbs_multi += nbs
            else:
                nbs_multi.append(nbs)
        adj_list_multi.append(nbs_multi)
    return adj_list_multi

def block_diag(mtr_size, block_size):
    """Return a ``[mtr_size, mtr_size]`` float tensor with ones on equal diagonal blocks.

    The diagonal is tiled with ``block_size x block_size`` blocks of ones; every
    other entry is zero.

    Raises:
        ValueError: if ``mtr_size`` is not a multiple of ``block_size``.
    """
    mask = torch.zeros([mtr_size, mtr_size])
    if mtr_size % block_size:
        raise ValueError('matrix size cannot be divided evenly by block size')
    for start in range(0, mtr_size, block_size):
        stop = start + block_size
        mask[start:stop, start:stop] = 1
    return mask

def chunks(lst, n, shuffle = True):
    """Yield successive ``n``-sized slices of ``lst``; the last slice may be shorter.

    ``shuffle`` is accepted for compatibility but ignored: slices are always
    produced in order.
    """
    yield from (lst[start:start + n] for start in range(0, len(lst), n))

def parse_index_file(filename):
    """Read one integer per line from ``filename`` and return them as a list.

    Surrounding whitespace on each line is ignored. A blank or non-integer line
    raises ``ValueError``. The file is closed even if parsing fails.
    """
    with open(filename) as f:
        return [int(line.strip()) for line in f]

def sample_mask(idx, l):
    """Return a boolean mask of length ``l`` that is True at positions ``idx``.

    ``idx`` may be anything NumPy accepts as an index (an int, a list or array
    of ints, ...). Uses the builtin ``bool`` dtype; the ``np.bool`` alias was
    removed in NumPy 1.24.
    """
    mask = np.zeros(l, dtype=bool)
    mask[idx] = True
    return mask

# NOTE(review): disabled debugging helper, kept inside a bare module-level string
# literal so it is never executed. If revived, `'detect inf in {}'.str(variable)`
# must become `.format(variable)` (str has no `.str` method), and the NaN
# message hard-codes 'mark6' instead of identifying the variable.
'''
def inf_nan_detect(variable):
    if torch.sum(torch.isinf(variable)) > 0:
        print('detect inf in {}'.str(variable))
        exit()
    if torch.sum(torch.isnan(variable)) > 0:
        print('detect nan in mark6')
        exit()
'''
class MLP_classifier(nn.Module):
    """Stack of ``nn.Linear`` layers whose forward pass returns a BCE-with-logits loss.

    ``channel_list`` is a sequence of ``(in_channels, out_channels)`` pairs, one
    per layer. No activation is applied between layers, so the stack is a
    composition of affine maps.
    """
    def __init__(self, channel_list):
        super().__init__()
        self.loss = nn.BCEWithLogitsLoss()
        layers = [nn.Linear(n_in, n_out) for n_in, n_out in channel_list]
        self.layer_list = nn.ModuleList(layers)

    def forward(self, input, labels):
        """Apply every layer to ``input`` and return the BCE-with-logits loss against ``labels``."""
        logits = input
        for linear in self.layer_list:
            logits = linear(logits)
        return self.loss(logits, labels)

def onehot_encoding(length, ids):
    """One-hot encode an index or a sequence of indices.

    Args:
        length: size of each one-hot vector.
        ids: a single integer index (Python ``int`` or NumPy integer scalar) or
            a sequence of integer indices.

    Returns:
        A float array of shape ``[length]`` for a single index, otherwise
        ``[len(ids), length]``.
    """
    # NumPy integer scalars (e.g. an element of a label array) count as single
    # indices too; a bare `type(ids) == int` check sent them to len() and crashed.
    if isinstance(ids, (int, np.integer)) and not isinstance(ids, bool):
        encoding = np.zeros(length)
        encoding[ids] = 1
        return encoding
    output = np.zeros([len(ids), length])
    for row, col in enumerate(ids):
        output[row][col] = 1
    return output

def binary_position_encoding(num_digits, ids):
    """Encode non-negative integers as concatenated one-hot decimal digits.

    Despite the name, this is not a binary encoding. Each integer is written in
    base 10, left-padded with zeros to ``num_digits`` digits, and each digit
    becomes a length-10 one-hot vector. The vectors are concatenated, most
    significant digit first.

    Args:
        num_digits: number of decimal digits to encode.
        ids: a single integer (Python ``int`` or NumPy integer scalar) or an
            iterable of integers.

    Returns:
        A float array of shape ``[num_digits * 10]`` for a single integer,
        otherwise ``[len(ids), num_digits * 10]``.

    Raises:
        ValueError: if a value is negative or needs more than ``num_digits``
            digits. Such values used to be silently truncated to their leading
            digits.
    """
    def _encode_one(value):
        # One-hot encode the zero-padded decimal digits of a single value.
        text = str(value)
        if not text.isdigit():
            raise ValueError('expected a non-negative integer, got {!r}'.format(value))
        if len(text) > num_digits:
            raise ValueError('{} needs more than {} digits'.format(value, num_digits))
        encoding = np.zeros([num_digits, 10])
        for pos, digit in enumerate(text.zfill(num_digits)):
            encoding[pos][int(digit)] = 1
        return encoding.reshape(-1)

    if isinstance(ids, (int, np.integer)) and not isinstance(ids, bool):
        return _encode_one(ids)
    return np.array([_encode_one(value) for value in ids])

def load_data_part_G(data_path, dataset_str, class_ids, n_hop, flatten):
    """
    Load a Planetoid-format dataset restricted to the classes in class_ids (incremental learning setting).

    Files are read from "{data_path}/data/ind.{dataset_str}.{name}" for name in
    x, y, tx, ty, allx, ally, graph, plus "{data_path}/data/ind.{dataset_str}.test.index":

    ind.dataset_str.x => the feature vectors of the training instances as scipy.sparse.csr.csr_matrix object;
    ind.dataset_str.tx => the feature vectors of the test instances as scipy.sparse.csr.csr_matrix object;
    ind.dataset_str.allx => the feature vectors of both labeled and unlabeled training instances
        (a superset of ind.dataset_str.x) as scipy.sparse.csr.csr_matrix object;
    ind.dataset_str.y => the one-hot labels of the labeled training instances as numpy.ndarray object;
    ind.dataset_str.ty => the one-hot labels of the test instances as numpy.ndarray object;
    ind.dataset_str.ally => the labels for instances in ind.dataset_str.allx as numpy.ndarray object;
    ind.dataset_str.graph => a dict in the format {index: [index_of_neighbor_nodes]} as collections.defaultdict
        object;
    ind.dataset_str.test.index => the indices of test instances in graph, for the inductive setting as list object.

    All objects above must be saved using python pickle module.
    NOTE(review): pickle.load can execute arbitrary code, so only load trusted files.

    :param data_path: root directory that contains the "data" sub-directory.
    :param dataset_str: Dataset name. For 'citeseer', ids in [min, max] of test.index
        that are absent from it get all-zero feature and label rows.
    :param class_ids: indices of the classes (label columns) to keep.
    :param n_hop: number of hops, passed to multi_nb.
    :param flatten: passed to multi_nb (concatenates the per-hop lists when truthy).
    :return: (idx_train, idx_val, idx_test, graph, multi_nbs, features, y_train, y_val, y_test, labels), where
        idx_* are the node ids of the kept classes in each split;
        graph is the loaded adjacency dict, modified in place (nodes outside the kept
        classes become self-loops only, edges to such nodes are removed);
        multi_nbs is multi_nb(graph, n_hop, flatten);
        features is the dense feature matrix (todense() of the sparse stack);
        y_* hold the label rows of the split's nodes (zeros elsewhere) and labels the rows
        of all nodes, each restricted to the columns class_ids.
    """
    #num_class = len(class_ids)
    names = ['x', 'y', 'tx', 'ty', 'allx', 'ally', 'graph']
    objects = []
    for i in range(len(names)):
        with open("{}/data/ind.{}.{}".format(data_path, dataset_str, names[i]), 'rb') as f:
            if sys.version_info > (3, 0):
                # encoding='latin1' lets Python 3 read pickles written by Python 2
                # (assumed origin of these files -- confirm).
                objects.append(pkl.load(f, encoding='latin1'))
            else:
                objects.append(pkl.load(f))

    x, y, tx, ty, allx, ally, graph = tuple(objects)
    num_class_total = y.shape[1]  # total number of classes = width of the one-hot label rows
    test_idx_reorder = parse_index_file("{}/data/ind.{}.test.index".format(data_path, dataset_str))
    test_idx_range = np.sort(test_idx_reorder)

    if dataset_str == 'citeseer':
        # Fix citeseer dataset (there are some isolated nodes in the graph)
        # Find isolated nodes, add them as zero-vecs into the right position
        test_idx_range_full = range(min(test_idx_reorder), max(test_idx_reorder)+1)
        tx_extended = sp.lil_matrix((len(test_idx_range_full), x.shape[1]))
        tx_extended[test_idx_range-min(test_idx_range), :] = tx
        tx = tx_extended
        ty_extended = np.zeros((len(test_idx_range_full), y.shape[1]))
        ty_extended[test_idx_range-min(test_idx_range), :] = ty
        ty = ty_extended

    # Stack training-pool and test features, then move test rows: row
    # test_idx_reorder[j] receives row test_idx_range[j].
    features = sp.vstack((allx, tx)).tolil()
    features[test_idx_reorder, :] = features[test_idx_range, :]

    # adj = nx.adjacency_matrix(nx.from_dict_of_lists(graph))

    # Same stacking and test-row permutation for the one-hot labels.
    labels = np.vstack((ally, ty))
    labels[test_idx_reorder, :] = labels[test_idx_range, :]

    # One integer one-hot row per kept class.
    class_labels = []
    for c in class_ids:
        label = sample_mask(c, y.shape[1]) * 1
        class_labels.append(label)
    # Restrict the graph to the kept classes, in place.
    for k in graph:
        jm = (labels[k] == class_labels) # compare label with candidate classes
        jm = np.sum(jm, 1) # if label matches a candidate class exactly, num_class_total will be in jm after sum
        jm = (jm == num_class_total)
        if not np.any(jm):
            # if a node does not belong to current classes, then isolate it
            graph[k] = [k]
        else:
            to_pop = []
            for t in range(len(graph[k])):
                #print('t is', t)
                jm1 = (labels[graph[k][t]] == class_labels)  # compare label with candidate classes
                jm1 = np.sum(jm1, 1)  # if label matches a candidate class exactly, num_class_total will be in jm1 after sum
                jm1 = (jm1 == num_class_total)
                if not np.any(jm1):
                    # if a node connects to a neighbor not in current class, remove this neighbor
                    to_pop.append(graph[k][t])
            # Removal happens after the scan so graph[k] is not mutated while indexed.
            for p in to_pop:
                graph[k].remove(p)

    #adj = nx.adjacency_matrix(nx.from_dict_of_lists(graph))

    # idx_test = test_idx_range.tolist()
    # idx_val = range(len(y), len(y) + 500)
    # idx_train = range(len(y))
    # Candidate splits: test = ids listed in test.index; val = the 500 ids right
    # after the labelled training rows; train = rows of y.
    idx_test_candidate = test_idx_range.tolist()
    idx_test = []
    idx_val_candidate = list(range(len(y), len(y) + 500))
    idx_val = []

    # Keep only candidates whose label is one of the kept classes (grouped by class).
    idx_train = []
    for label in class_labels:
        ids_selected = np.matmul(y, label).nonzero()[0].tolist()
        idx_train = idx_train + ids_selected
        for id_can in idx_test_candidate:
            if sum(labels[id_can] * label):
                idx_test.append(id_can)
        for id_can in idx_val_candidate:
            if sum(labels[id_can] * label):
                idx_val.append(id_can)

    train_mask = sample_mask(idx_train, labels.shape[0])
    val_mask = sample_mask(idx_val, labels.shape[0])
    test_mask = sample_mask(idx_test, labels.shape[0])

    # Per-split label matrices: label rows of the split's nodes, zeros elsewhere.
    y_train = np.zeros(labels.shape)
    y_val = np.zeros(labels.shape)
    y_test = np.zeros(labels.shape)
    y_train[train_mask, :] = labels[train_mask, :]
    y_val[val_mask, :] = labels[val_mask, :]
    y_test[test_mask, :] = labels[test_mask, :]

    # select only the involved classes in the labels
    y_train = y_train[:,class_ids]
    y_val = y_val[:, class_ids]
    y_test = y_test[:, class_ids]
    labels = labels[:, class_ids]

    # return multi-hop lists
    multi_nbs = multi_nb(graph, n_hop, flatten)
    return idx_train, idx_val, idx_test, graph, multi_nbs, features.todense(), y_train, y_val, y_test, labels

def _geom_read_edge_lists(edge_path):
    """Read a tab-separated edge list into an undirected ``{node: [neighbours]}`` dict.

    The first line is skipped as a header. Every other non-blank line must start
    with ``source<TAB>target``. Each edge is recorded in both directions, and a
    neighbour is stored at most once per node.
    """
    graph_ = dict()
    with open(edge_path, 'r') as f:
        lines = f.readlines()[1:]
    for edge in lines:
        if not edge.strip():
            continue  # tolerate blank lines, e.g. a trailing newline
        fields = edge.split('\t')
        source, target = int(fields[0]), int(fields[1])
        # Check the *receiving* node's list before appending, in both directions
        # (the reverse check previously looked at graph_[source] and duplicated edges).
        for node, neighbour in ((source, target), (target, source)):
            neighbours = graph_.setdefault(node, [])
            if neighbour not in neighbours:
                neighbours.append(neighbour)
    return graph_


def _geom_read_features_labels(feature_path, dataset_str):
    """Read node features and integer labels from an ``id<TAB>features<TAB>label`` file.

    The first line is a header. For ``'film'``, the header gives the feature
    dimension (the text between the first ':' and the next ')'). Each feature
    field then lists comma-separated feature indices, treated as 1-based and set
    to 1. For other datasets, the feature field holds one comma-separated integer
    per dimension.

    Returns:
        ``(features, labels_)``: a float array ``[n_nodes, data_dim]`` and an int
        array ``[n_nodes]``, both indexed by node id. ``n_nodes`` is the number
        of data lines.
    """
    with open(feature_path, 'r') as f:
        data_str = f.readlines()
    body = data_str[1:]
    if dataset_str == 'film':
        data_dim = int(data_str[0].split(':')[1].split(')')[0])
    else:
        data_dim = len(body[0].split('\t')[1].split(','))
    n_nodes = len(body)
    features = np.zeros([n_nodes, data_dim])
    labels_ = np.zeros(n_nodes, dtype=int)
    for data in body:
        id_str, feat_str, label_str = data.split('\t')
        node = int(id_str)
        label = int(label_str)
        feat_str_list = feat_str.split(',')
        if dataset_str == 'film':
            for feat_id in feat_str_list:
                features[node][int(feat_id) - 1] = 1.
        else:
            for i, feat in enumerate(feat_str_list):
                features[node][i] = int(feat)
        labels_[node] = label
    return features, labels_


def _geom_label_in_classes(label_row, class_labels, num_class_total):
    """Return True when the one-hot ``label_row`` equals one of ``class_labels`` exactly."""
    matches = np.sum(label_row == class_labels, 1) == num_class_total
    return np.any(matches)


def load_newdata_part_G(data_path, dataset_str, class_ids, n_hop, flatten):
    """Load a dataset stored as edge/feature text files, restricted to ``class_ids``.

    Reads ``{data_path}/{dataset_str}/out1_graph_edges.txt`` (tab-separated
    edges, one header line) and ``{data_path}/{dataset_str}/out1_node_feature_label.txt``
    (``id<TAB>features<TAB>label`` lines, one header line).

    For every kept class, its node ids are taken in increasing order and split
    60% / 20% / 20% into train / val / test. The graph is then restricted to
    the kept classes: nodes of other classes keep only a self-loop, and edges
    to such nodes are removed.

    :param data_path: root directory containing one sub-directory per dataset.
    :param dataset_str: dataset name; 'film' uses a sparse feature format.
    :param class_ids: label values (= one-hot columns) to keep.
    :param n_hop: number of hops, passed to multi_nb.
    :param flatten: passed to multi_nb.
    :return: (idx_train, idx_val, idx_test, graph, multi_nbs, features, y_train, y_val, y_test, labels).
        y_* hold the one-hot rows of the split's nodes (zeros elsewhere) and
        labels the rows of all nodes, each restricted to the columns class_ids.
    """
    graph_ = _geom_read_edge_lists("{}/{}/out1_graph_edges.txt".format(data_path, dataset_str))
    # Node ids are assumed to be exactly 0 .. len(graph_) - 1; a missing id
    # (e.g. a node without any edge) raises KeyError here.
    graph = {node: graph_[node] for node in range(len(graph_))}

    features, labels_ = _geom_read_features_labels(
        "{}/{}/out1_node_feature_label.txt".format(data_path, dataset_str), dataset_str)
    labels = onehot_encoding(int(max(labels_)+1), labels_)

    idx_train, idx_val, idx_test = [], [], []
    for c in class_ids:
        ids_cl = (labels_ == c).nonzero()[0].tolist()  # ids of the nodes with label c, ascending
        split_tv, split_vt = int(len(ids_cl) * 0.6), int(len(ids_cl) * 0.8)  # train|val and val|test cut points
        idx_train = idx_train + ids_cl[0:split_tv]
        idx_val = idx_val + ids_cl[split_tv:split_vt]
        idx_test = idx_test + ids_cl[split_vt:]

    num_class_total = labels.shape[1]
    class_labels = [sample_mask(c, labels.shape[1]) * 1 for c in class_ids]
    for k in graph:
        if not _geom_label_in_classes(labels[k], class_labels, num_class_total):
            # a node outside the current classes is isolated (self-loop only)
            graph[k] = [k]
        else:
            # drop neighbours outside the current classes (collected first, then removed)
            to_pop = [nb for nb in graph[k]
                      if not _geom_label_in_classes(labels[nb], class_labels, num_class_total)]
            for p in to_pop:
                graph[k].remove(p)

    train_mask = sample_mask(idx_train, labels.shape[0])
    val_mask = sample_mask(idx_val, labels.shape[0])
    test_mask = sample_mask(idx_test, labels.shape[0])

    y_train = np.zeros(labels.shape)
    y_val = np.zeros(labels.shape)
    y_test = np.zeros(labels.shape)
    y_train[train_mask, :] = labels[train_mask, :]
    y_val[val_mask, :] = labels[val_mask, :]
    y_test[test_mask, :] = labels[test_mask, :]

    # select only the involved classes in the labels
    y_train = y_train[:, class_ids]
    y_val = y_val[:, class_ids]
    y_test = y_test[:, class_ids]
    labels = labels[:, class_ids]

    # return multi-hop lists
    multi_nbs = multi_nb(graph, n_hop, flatten)
    return idx_train, idx_val, idx_test, graph, multi_nbs, features, y_train, y_val, y_test, labels


def load_data(data_path, dataset_str, class_ids, n_hop, flatten):
    """Dispatch to the loader matching ``dataset_str``.

    'cora' and 'citeseer' use the pickled-file loader ``load_data_part_G``.
    Every other name uses the text-file loader ``load_newdata_part_G``. All
    arguments are forwarded unchanged, and the loader's result is returned.
    """
    pickled_datasets = ('cora', 'citeseer')
    loader = load_data_part_G if dataset_str in pickled_datasets else load_newdata_part_G
    return loader(data_path, dataset_str, class_ids, n_hop, flatten)


class atten_classifier_GAT_o(nn.Module):
    """GAT-style attention classifier over atom embeddings and one object embedding.

    Atoms and the object are projected by a shared bias-free linear map ``w``.
    ``a`` scores every (atom, object) pair; the scores are used unnormalised
    (no softmax) to aggregate the projected atoms into a message for the
    object. The object projection and the message are combined by sum
    (``inte='add'`` or its alias ``'sum'``) or by concatenation
    (``inte='concat'``) and then classified.
    """
    # performance is around 74 on 7 classes, not very good
    def __init__(self, dim_original, dim_attention, num_class, inte = 'add'):
        super(atten_classifier_GAT_o, self).__init__()
        if inte not in ('add', 'sum', 'concat'):
            raise ValueError("inte must be 'add', 'sum' or 'concat', got {!r}".format(inte))
        self.inte = inte
        self.w = nn.Linear(dim_original, dim_attention, bias=False)
        self.a = nn.Linear(2 * dim_attention, 1, bias=False)
        if inte == 'concat':
            self.classifier = nn.Linear(2 * dim_attention, num_class)
        else:
            self.classifier = nn.Linear(dim_attention, num_class)

    def forward(self, atom, obj, cls = None):
        """Return logits ``[batch_size, num_class]``.

        ``atom`` is ``[batch_size, lr+la, dim_original]`` and ``obj`` is
        ``[batch_size, dim_original]``. ``cls`` is unused.
        """
        batch_size = obj.shape[0]
        a1 = self.w(atom)  # [batch, lr+la, dim_attention]
        # The projection has width dim_attention, not the input width.
        o1 = self.w(obj).view(batch_size, 1, -1)  # [batch, 1, dim_attention]
        o2 = o1.expand(-1, a1.shape[1], -1)  # pair the object with every atom
        att_weights = self.a(torch.cat((a1, o2), dim=-1)).transpose(1, 2)  # [batch, 1, lr+la]
        # squeeze only axis 1 so a batch of size 1 keeps its batch axis
        message_a2o = att_weights.bmm(a1).squeeze(1)  # [batch, dim_attention]
        obj_proj = o1.squeeze(1)
        if self.inte == 'concat':
            return self.classifier(torch.cat((obj_proj, message_a2o), dim=-1))
        # 'add' (the default) and 'sum' both combine by addition; previously
        # forward only recognised 'sum', so the default mode crashed.
        return self.classifier(obj_proj + message_a2o)


class MLP(nn.Module):
    """Plain stack of ``nn.Linear`` layers applied in order.

    ``channel_list`` holds one ``(in_channels, out_channels)`` pair per layer.
    No activation is applied between layers, so the network is a composition
    of affine maps. ``activation`` is accepted but unused, and
    ``self.softmax`` is created but not used by ``forward``.
    """
    def __init__(self, channel_list, activation='relu'):  # number of in and out channels required
        super().__init__()
        self.layer_list = nn.ModuleList(
            nn.Linear(n_in, n_out) for n_in, n_out in channel_list
        )
        self.softmax = nn.Softmax()

    def forward(self, input):
        """Feed ``input`` through every linear layer and return the raw output."""
        hidden = input
        for linear in self.layer_list:
            hidden = linear(hidden)
        return hidden

def Pairwise_dis_shrink(x):
    """Return the sum of all entries of the Gram matrix ``x @ x^T``.

    This is the sum of the dot products of every ordered pair of rows of ``x``,
    diagonal included. Minimising it pulls the vectors towards each other.
    """
    gram = x.matmul(x.transpose(1, 0))
    return gram.sum()

def Pairwise_dis_loss(x, d_min=0.1, mask_d = None, batch = None):
    """Negative sum of the pairwise dot products that fall below ``d_min``.

    Only the strict upper triangle of the Gram matrix (each unordered pair once,
    no diagonal) is used. ``x`` is ``[n, dim]``, or ``[batch, n, dim]`` when
    ``batch`` is truthy. An optional ``mask_d`` is multiplied element-wise into
    the triangle before thresholding.
    """
    gram = x.bmm(x.transpose(1, 2)) if batch else x.mm(x.transpose(1, 0))
    sim = torch.triu(gram, diagonal=1)
    if mask_d is not None:
        sim = sim * mask_d
    below = (sim < d_min).float()
    return -(sim * below).sum()

def Pairwise_dis_loss_(x, d_min=0.1):
    """Sum of the entries of the Gram matrix ``x @ x^T`` that exceed ``d_min``.

    The full matrix is used, diagonal included.
    """
    gram = x.matmul(x.transpose(1, 0))
    above = gram.gt(d_min).float()
    return (gram * above).sum()

class Relation_emb(nn.Module):
    """Given a vertex and its neighbours, embed each vertex pair (work in progress).

    Only the layer stack is built: one ``nn.Linear`` per ``(in_channels,
    out_channels)`` pair in ``channel_list``. No ``forward`` is defined here,
    so calling an instance falls back to ``nn.Module``'s unimplemented forward.
    """
    def __init__(self, channel_list):
        super().__init__()
        layers = [nn.Linear(n_in, n_out) for n_in, n_out in channel_list]
        self.layer_list = nn.ModuleList(layers)

    # Planned signature: forward(self, train_ids, train_features, adj_list, edge_feats=None)

class Component_prototypes(nn.Module):
    def __init__(self, dim_proto, dim_cls, num_atoms, l_a, l_r, num_objs = 5000, num_cls = 10000):
        """Allocate prototype tables, their bookkeeping tensors and the atom->object layers.

        Args:
            dim_proto: width of atom and object prototypes; update() matches
                embeddings of this width against the atom table.
            dim_cls: width of class prototypes.
            num_atoms: number of rows allocated in ``self.atoms``.
            l_a: number of attribute atomic embeddings to select.
            l_r: NOTE(review): presumably the relation-atom counterpart of
                ``l_a`` (cf. atom_r_splits); confirm.
            num_objs: number of rows allocated in ``self.objs``.
            num_cls: number of rows allocated in ``self.cls``.
        """
        super(Component_prototypes, self).__init__()
        self.l_a = l_a # number of attribute atomic embeddings to select
        self.l_r = l_r
        self.l_relu = nn.LeakyReLU()

        # Learnable layers. Keep the construction order stable: it fixes the
        # parameter registration order and the RNG draws used for initialisation.
        self.a_o_emb = nn.Linear((l_a+l_r)*dim_proto, dim_proto, bias=True) # embed concat atoms to objs
        self.o_c_emb = nn.Linear(dim_proto, dim_cls) # embed objs to be matched to cls
        self.a_o_attention = nn.Linear((l_a+l_r)*dim_proto, (l_r+l_a)) # attention on atoms for objs
        self.a_o_att_w_GAT = nn.Linear(dim_proto, dim_proto, bias=False)
        self.a_o_att_a_GAT = nn.Linear(dim_proto*2, 1, bias=False)
        self.a_o_mask_emb = nn.Linear((l_a+l_r)*dim_proto, (l_a+l_r)*dim_proto)
        # Fixed-capacity prototype tables, initialised uniformly in [0, 0.5).
        self.atoms = Parameter(torch.empty(num_atoms, dim_proto).uniform_(0,0.5), requires_grad=True)  # atom prototypes
        self.objs = Parameter(torch.empty(num_objs, dim_proto).uniform_(0,0.5), requires_grad=True) # object prototypes
        self.cls = Parameter(torch.empty(num_cls, dim_cls).uniform_(0,0.5), requires_grad=True)  # class prototypes
        # Bookkeeping tensors. These are plain attributes (neither Parameters nor
        # registered buffers), so Module.to()/.cuda() and state_dict() ignore them;
        # update() moves atom_stat/obj_stat/cls_stat to the embeddings' GPU itself.
        self.atom_stat = torch.zeros(num_atoms, requires_grad=False)  # per-atom statistic; update() accumulates assignment frequencies here
        self.obj_stat = torch.zeros(num_objs, requires_grad=False)  # per-object statistic, shape (num_objs,)
        self.cls_stat = torch.zeros(num_cls, requires_grad=False) # per-class statistic, shape (num_cls,)
        self.obj_atom_map = torch.zeros([num_objs,num_atoms], requires_grad=False) # object-to-atom connection matrix, shape (num_objs, num_atoms)
        self.cls_obj_map = torch.zeros([num_cls,num_objs], requires_grad=False) # class-to-object connection matrix, shape (num_cls, num_objs)

        self.n_atom_total = num_atoms  # capacity of self.atoms
        self.num_atoms = 1  # next free row of self.atoms; update() writes new atoms starting here
        self.num_atoms_old = 1
        self.atom_a_splits = [0] # record which embedding module each attr atom corresponds to
        self.atom_r_splits = [0]  # record which embedding module each rela atom corresponds to
        self.num_objs = 1
        self.num_cls = 1

    def update(self, c_ids, embeddings, threshold, est_proto, threshold_c = 0.4, task_id = None): # nei_embs are embeddings of neighboring nodes
        d_proto = embeddings.shape[-1]
        n_AEM = len(self.atom_a_splits) # number of sets of embedding modules
        n_AEM_c = int(embeddings.shape[1]/(self.l_r+self.l_a))
        batch_size = len(c_ids)
        # allocate to GPU
        self.atom_stat = self.atom_stat.cuda(embeddings.get_device())
        self.obj_stat = self.obj_stat.cuda(embeddings.get_device())
        self.cls_stat = self.cls_stat.cuda(embeddings.get_device())

        # deal with current node (atoms)
        emb_set = [embeddings[:, i * (self.l_a + self.l_r): (i + 1) * (self.l_r + self.l_a), :].contiguous().view(-1, d_proto)
                   for i in range(n_AEM_c)] # n_AEM * [batch*(n_AEM_a + lr), d_proto]
        atom_a_splits = self.atom_a_splits.copy()
        if atom_a_splits[-1] == self.num_atoms:
            self.num_atoms+=1
        atom_a_splits.append(self.num_atoms)
        atom_set = [self.atoms[atom_a_splits[i-(n_AEM_c+1)]: atom_a_splits[i + 1-(n_AEM_c+1)]] for i in range(n_AEM_c)]
        soft_corres_atom_set = [emb_set[i].mm(F.normalize(atom_set[i], dim=-1).transpose(1, 0)) for i in
                                range(n_AEM_c)]
        corres_max_set = [torch.max(soft_corres_atom_set[i], dim=1)[0] for i in
                          range(n_AEM_c)]  # get the max value of each row

        c_embs = emb_set[-1]
        c_embs = c_embs.view(-1, d_proto)  # [batch_size*(lr+la), d_proto]

        c_embs_ = c_embs.detach().cpu().numpy()
        num_embs = c_embs.shape[0]
        soft_corres_atom = soft_corres_atom_set[-1]  # cosine dist between embeddings and protos [num_embeddings*num_protos]
        corres_max = corres_max_set[-1]  # get the max value of each row
        new_protos = corres_max < (1 - threshold)  # denote which embeddings need establishing new protos
        #new_protos = (corres_max < (1 - threshold)) * (corres_max!=0)
        new_proto_indices = new_protos.nonzero().squeeze(dim=-1)  # indices of embeddings for new protos

        if est_proto:
            # create new atoms for current nodes
            if len(new_proto_indices) != 0:
                # avoid redundant atoms caused by similar embeddings
                embs_for_new = c_embs[new_proto_indices].detach()
                emb_sim = embs_for_new.mm(embs_for_new.transpose(1, 0))
                emb_sim = torch.triu(emb_sim, diagonal=1)  # get upper tri-matrix
                emb_sim = torch.max(emb_sim, 1)[0]  # get the max similarity between each emb and other embs
                new_proto_indices = (emb_sim < (1 - threshold)).nonzero().squeeze(dim=-1)  #

                with torch.no_grad():
                    self.atoms[self.num_atoms: self.num_atoms + len(new_proto_indices)] = embs_for_new[new_proto_indices]
                    atoms_ = self.atoms.detach().cpu().numpy()
                self.num_atoms += len(new_proto_indices)
                self.num_atoms_old = self.num_atoms
                atom_a_splits[-1] = self.num_atoms
                atom_set = [self.atoms[atom_a_splits[i-(n_AEM_c+1)]: atom_a_splits[i + 1-(n_AEM_c+1)]] for i in range(n_AEM_c)]
                soft_corres_atom = c_embs.mm(F.normalize(atom_set[-1], dim=-1).transpose(1, 0))

            max_logits = soft_corres_atom.max(dim=1)[0]
            sorted_ids = torch.sort(max_logits.view(batch_size, -1), dim=1, descending=True)[1] # [batch_size, lr+la]
            selected_sorted_ids = sorted_ids[:, 0:(self.l_r+self.l_a)].view(batch_size, self.l_r+self.l_a, 1) #[batch_size, lr+la, 1] select fixed number of atoms
            id_batch = torch.tensor(range(batch_size)).view(batch_size,1,1)
            c_num_atoms = int(atom_a_splits[-1] - atom_a_splits[-2])
            id_atom = torch.tensor(range(c_num_atoms)).view(1,1,c_num_atoms)
            max_logits = max_logits.view(num_embs, 1)
            hard_corres_atom = (soft_corres_atom == max_logits).float().view(batch_size, -1, c_num_atoms) # [batch_size, lr+la, n_atoms]
            hard_corres_atom = hard_corres_atom[id_batch, selected_sorted_ids, id_atom].view(batch_size*(self.l_r+self.l_a), c_num_atoms) # select the most close lr+la atoms
            associated_protos = hard_corres_atom.mm(atom_set[-1])
            # update stat of atoms
            self.atom_stat[atom_a_splits[-2]:self.num_atoms] += torch.mean(hard_corres_atom, 0)

            # objs
        else:
            max_logits_ = [soft_corres_atom_set[i].max(dim=1)[0] for i in range(n_AEM_c)] # n_AEM_c * [batch*lr+la]
            votes = [i.mean() for i in max_logits_]
            #print('votes :', votes)
            max_logits = torch.cat([m.view(batch_size, self.l_r+self.l_a) for m in max_logits_], dim=1) # [batch, n_AEM_c * (lr+la)]
            voted_AEM = int(np.argmax(votes))
            if task_id is not None and n_AEM_c>1:
                voted_AEM = task_id
            sorted_ids = torch.sort(max_logits.view(batch_size, -1), dim=1, descending=True)[1]  # [batch_size, lr+la]
            sorted_ids_n = sorted_ids.detach().cpu().numpy()
            selected_sorted_ids = sorted_ids[:, 0:(self.l_r + self.l_a)].view(batch_size, self.l_r + self.l_a,
                                                                              1)  # [batch_size, lr+la, 1] select fixed number of atoms
            '''
            votes = selected_sorted_ids//3
            vote_result = [(votes==i).sum() for i in range(n_AEM_c)]
            voted_AEM = int(np.argmax(vote_result))
            '''

            print('voted AEM is {}'.format(voted_AEM))
            #voted_AEM = 0
            #id_batch = torch.tensor(range(batch_size)).view(batch_size, 1, 1)
            #id_atom = torch.tensor(range(self.num_atoms)).view(1, 1, self.num_atoms)
            #max_logits = max_logits.view(num_embs, 1)
            #hard_corres_atom_ = (soft_corres_atom == max_logits).float().view(batch_size, -1, self.num_atoms)
            #hard_corres_atom__n = hard_corres_atom_.detach().cpu().numpy()
            #hard_corres_atom = hard_corres_atom_[id_batch, selected_sorted_ids, id_atom].view(batch_size * (self.l_r + self.l_a), self.num_atoms)
            hard_corres_atom = (soft_corres_atom_set[voted_AEM] == max_logits_[voted_AEM].view(-1,1)).float()
            hard_corres_atom_n = hard_corres_atom.detach().cpu().numpy()
            associated_protos = hard_corres_atom.mm(atom_set[voted_AEM])

        #associated_protos = hard_corres_atom.mm(self.atoms[0:self.num_atoms])
        associated_protos = F.normalize(associated_protos, dim=-1)

        # for objs

        # 1. use concat of atoms to generate attention weights
        a = associated_protos.view(batch_size, (self.l_r+self.l_a)*d_proto)
        atten_weight = self.a_o_attention(a).view(batch_size, 1, (self.l_r+self.l_a))# [batch_size, 1, l_r+l_a], attention weight for each atom
        #atten_weight = F.softmax(atten_weight, dim=-1) #has certain performance without this line
        b = associated_protos.view(batch_size, self.l_r+self.l_a, d_proto)
        obj_embs = F.normalize(atten_weight.bmm(b).view(batch_size, d_proto), dim=-1) # [batch, d_proto] each node gets an object level proto

        '''
        # 2. use current batch to generate attention weights
        a = associated_protos.view(batch_size, (self.l_r+self.l_a)*d_proto).mean(0)
        atten_weight = self.a_o_attention(a).view(1, (self.l_r+self.l_a), 1)
        atten_weight = F.softmax(atten_weight, 1)
        b = associated_protos.view(batch_size, self.l_r+self.l_a, d_proto)
        c = atten_weight * b
        obj_embs = c.mean(1)
        '''
        '''
        # 3. embed concat of atoms into objs
        a = associated_protos.view(batch_size, (self.l_r + self.l_a) * d_proto)
        obj_embs = self.a_o_emb(a) # [batch_size, d_proto]
        '''
        '''
        # 4. mask the concat of atoms for objs
        a = associated_protos.view(batch_size, (self.l_r + self.l_a) * d_proto)
        mask = self.a_o_mask_emb(a).sigmoid()
        obj_embs = mask*a
        
        # 5. try the mean of atoms as objs
        obj_embs = associated_protos.view(batch_size, (self.l_a+self.l_r), d_proto).mean(1)
        
        # 6. attention (GAT version attention) update each atom then aggregate as an obj
        a = associated_protos.view(batch_size, (self.l_r+self.l_a), d_proto)
        b = self.a_o_att_w_GAT(a).view(batch_size, (self.l_r+self.l_a), 1, d_proto) # [batch_size, lr+la, 1, d_proto]
        c = b.repeat([1, 1, (self.l_r+self.l_a), 1]) # [batch_size, lr+la, lr+la, d_proto]
        d = c.transpose(1,2)
        e = torch.cat([c,d], dim=-1)
        atten_weight = self.a_o_att_a_GAT(e).squeeze()# [batch_size, lr+la, lr+la]
        atten_weight = F.softmax(atten_weight, dim=-1)
        ass_protos_up = atten_weight.bmm(a) # [batch_size, lr+la, d_proto]
        obj_embs = ass_protos_up.view(batch_size * (self.l_a+self.l_r), d_proto)
        '''
        # 7. attention (Transformer version) update atom to get obj

        # deal with current node (objects)
        #embeddings = embeddings.view(-1, d_proto)  # [batch_size*(lr+la), d_proto]
        #n_obj_embs = obj_embs.shape[0]
        soft_corres_obj = obj_embs.mm(F.normalize(self.objs[0:self.num_objs], dim=-1).transpose(1,
                                                                                                     0))  # cosine dist between embeddings and protos [num_embeddings*num_protos]
        corres_max_obj = torch.max(soft_corres_obj, dim=1)[0]  # get the max value of each row
        new_objs = corres_max_obj < (1 - threshold)  # denote which embeddings need establishing new objs
        #new_objs = (corres_max_obj < (1 - threshold))*(corres_max_obj!=0)  # denote which embeddings need establishing new objs
        new_obj_indices = new_objs.nonzero().squeeze(dim=-1)  # indices of positions of embeddings for new objs

        if est_proto:
            # create new atoms for current nodes
            if len(new_obj_indices) != 0:
                # avoid redundant atoms caused by similar embeddings
                obj_embs_for_new = obj_embs[new_obj_indices].detach()
                obj_emb_sim = obj_embs_for_new.mm(obj_embs_for_new.transpose(1, 0))
                obj_emb_sim = torch.triu(obj_emb_sim, diagonal=1)  # get upper tri-matrix
                obj_emb_sim = torch.max(obj_emb_sim, 1)[0] # get the max similarity between each emb and other embs
                new_obj_indices = (obj_emb_sim < (1 - threshold)).nonzero().squeeze(dim=-1)  #

                with torch.no_grad():
                    self.objs[self.num_objs: self.num_objs + len(new_obj_indices)] = obj_embs_for_new[new_obj_indices]
                self.num_objs += len(new_obj_indices)
                self.num_objs_old = self.num_objs
                soft_corres_obj = obj_embs.mm(F.normalize(self.objs[0:self.num_objs], dim=-1).transpose(1, 0))

            max_logits = soft_corres_obj.max(dim=1)[0]
            max_logits = max_logits.view(batch_size, 1)
            hard_corres_obj = (soft_corres_obj == max_logits).float()
            # update stat of atoms
            self.obj_stat[0:self.num_objs] += torch.mean(hard_corres_obj, 0)

            # objs
        else:
            max_logits = soft_corres_obj.max(dim=1)[0]
            max_logits = max_logits.view(batch_size, 1)
            hard_corres_obj = (soft_corres_obj == max_logits).float()

        associated_objs = hard_corres_obj.mm(self.objs[0:self.num_objs])
        associated_objs = F.normalize(associated_objs, dim=-1)

        # deal with current node (cls)
        # cls_embs = self.o_c_emb(associated_objs)
        cls_embs = associated_objs
        soft_corres_cls = cls_embs.mm(F.normalize(self.cls[0:self.num_cls], dim=-1).transpose(1,
                                                                                              0))  # cosine dist between embeddings and protos [num_embs*num_protos]
        corres_max_cls = torch.max(soft_corres_cls, dim=1)[0]  # get the max value of each row
        new_cls = corres_max_cls < (1 - threshold_c)  # denote which embeddings need establishing new objs
        #new_cls = (corres_max_cls < (1 - threshold_c))*(corres_max_cls!=0)  # denote which embeddings need establishing new objs
        new_cls_indices = new_cls.nonzero().squeeze(dim=-1)  # indices of positions of embeddings for new objs

        if est_proto:
            # create new objs for current nodes
            if len(new_cls_indices) != 0:
                # avoid redundant atoms caused by similar embeddings
                cls_embs_for_new = F.normalize(cls_embs[new_cls_indices], p=2, dim=-1).detach()
                cls_emb_sim = cls_embs_for_new.mm(cls_embs_for_new.transpose(1, 0))
                cls_emb_sim = torch.triu(cls_emb_sim, diagonal=1)  # get upper tri-matrix
                cls_emb_sim = torch.max(cls_emb_sim, 1)[0]  # get the max similarity between each emb and other embs
                new_cls_indices = (cls_emb_sim < (1 - threshold_c)).nonzero().squeeze(dim=-1)  #

                with torch.no_grad():
                    self.cls[self.num_cls: self.num_cls + len(new_cls_indices)] = cls_embs_for_new[new_cls_indices]
                self.num_cls += len(new_cls_indices)
                self.num_cls_old = self.num_cls
                soft_corres_cls = cls_embs.mm(F.normalize(self.cls[0:self.num_cls], dim=-1).transpose(1, 0))

            max_logits = soft_corres_cls.max(dim=1)[0]
            max_logits = max_logits.view(batch_size, 1)
            hard_corres_cls = (soft_corres_cls == max_logits).float()
            # update stat of objs
            self.cls_stat[0:self.num_cls] += torch.mean(hard_corres_cls, 0)

        else:
            max_logits = soft_corres_cls.max(dim=1)[0]
            max_logits = max_logits.view(batch_size, 1)
            hard_corres_cls = (soft_corres_cls == max_logits).float()

        associated_cls = hard_corres_cls.mm(self.cls[0:self.num_cls])
        associated_cls = F.normalize(associated_cls, dim=-1)

        return associated_protos, associated_objs, associated_cls, hard_corres_atom, hard_corres_obj, selected_sorted_ids

    def AFE_select(self, c_ids, embeddings, threshold, est_proto, threshold_c=0.4,
                   task_id=None):
        """Return the index of the atom-embedding module (AEM) that best fits ``embeddings``.

        ``embeddings`` is a 3-D tensor whose dim 1 is cut into consecutive
        groups of ``self.l_r + self.l_a`` slots, one group per AEM; the last dim
        is the embedding size. Group ``i`` is compared (dot product against
        L2-normalised rows) with the i-th slice of ``self.atoms``, where the
        slices are delimited by the last ``n_AEM_c + 1`` entries of
        ``self.atom_a_splits`` followed by ``self.num_atoms``. For every group the
        best-matching atom score of each embedding row is taken, those scores are
        averaged, and the group with the highest mean wins (first one on ties).

        Side effects:
            * ``self.atom_stat``, ``self.obj_stat`` and ``self.cls_stat`` are moved
              to ``embeddings.device``.
            * If the last entry of ``self.atom_a_splits`` equals ``self.num_atoms``,
              ``self.num_atoms`` is incremented by one so the last atom slice is
              not empty (assumes ``self.atoms`` has a row there — TODO confirm).
              ``self.atom_a_splits`` itself is not modified.

        Args:
            c_ids: ids of the batch nodes; not used by this method.
            embeddings: tensor as described above.
            threshold, est_proto, threshold_c, task_id: not used by this method;
                kept so the signature stays unchanged for callers.

        Returns:
            int: index in ``range(n_AEM_c)`` of the selected module.

        Raises:
            ValueError: if dim 1 of ``embeddings`` holds fewer than
                ``self.l_r + self.l_a`` slots (no group to vote on).
        """
        d_proto = embeddings.shape[-1]
        group_len = self.l_r + self.l_a
        n_AEM_c = embeddings.shape[1] // group_len  # number of AEM groups present
        if n_AEM_c == 0:
            raise ValueError('embeddings has fewer than l_r + l_a slots along dim 1')

        # Keep the statistics on the same device as the embeddings. `.to(device)`
        # also works for CPU tensors, where `get_device()` returns -1.
        device = embeddings.device
        self.atom_stat = self.atom_stat.to(device)
        self.obj_stat = self.obj_stat.to(device)
        self.cls_stat = self.cls_stat.to(device)

        # One [batch * group_len, d_proto] matrix per AEM group.
        emb_set = [
            embeddings[:, i * group_len:(i + 1) * group_len, :].contiguous().view(-1, d_proto)
            for i in range(n_AEM_c)]

        atom_a_splits = self.atom_a_splits.copy()
        if atom_a_splits[-1] == self.num_atoms:
            # The newest split would otherwise produce an empty atom slice.
            self.num_atoms += 1
        atom_a_splits.append(self.num_atoms)
        # The last n_AEM_c + 1 split points bound the atoms of the n_AEM_c groups.
        offset = n_AEM_c + 1
        atom_set = [self.atoms[atom_a_splits[i - offset]:atom_a_splits[i + 1 - offset]]
                    for i in range(n_AEM_c)]

        # Mean of the per-row best similarity, one score per group. Kept in
        # torch: NumPy cannot convert CUDA or grad-requiring tensors.
        votes = torch.stack([
            emb.mm(F.normalize(atoms, dim=-1).transpose(1, 0)).max(dim=1)[0].mean()
            for emb, atoms in zip(emb_set, atom_set)])
        return int(torch.argmax(votes))

    def cls_update(self, associated_atom_ids, threshold_c):
        """Attach the object built from the given atoms to an existing or a new class.

        ``associated_atom_ids`` is flattened to a row vector and multiplied with
        ``self.cur_atoms``. The result is scored against every row of
        ``self.cur_cls`` by dot product. If the best score exceeds
        ``1 - threshold_c``, a class index is drawn with a hard Gumbel-softmax
        over the scores, and ``self.num_objs`` is added to that class's set in
        ``self.class_stat``. Otherwise a new class with id ``self.num_cls`` is
        recorded in ``self.class_stat`` and ``self.cls_obj_map``, and
        ``self.num_cls`` is incremented. Debug values are printed to stdout.

        Args:
            associated_atom_ids: weights over the rows of ``self.cur_atoms``
                (presumably a one-/multi-hot selection — TODO confirm).
            threshold_c: a class matches when its score exceeds ``1 - threshold_c``.
        """
        row = associated_atom_ids.view(1, -1)
        combined = torch.matmul(row, self.cur_atoms)
        obj_col = torch.mean(combined, 0).view(-1, 1)
        scores = self.cur_cls.mm(obj_col).view(1, -1)  # one obj-cls score per class
        best = max(scores.view(-1))
        print('j is', best)
        print('codis is', scores)

        if best > (1 - threshold_c):
            # NOTE(review): with hard=True the chosen index is
            # argmax(scores + Gumbel noise); tau does not change it. The class is
            # therefore sampled from softmax(scores), not taken as the argmax, and
            # may even score below the threshold. Confirm this randomness is intended.
            one_hot = F.gumbel_softmax(scores, tau=0.1, hard=True).squeeze(dim=0)
            chosen = one_hot.nonzero().squeeze().item()
            print('comap is', one_hot)
            print('cls id', chosen)
            # NOTE(review): unlike the new-class path, cls_obj_map is not updated here.
            self.class_stat[chosen].add(self.num_objs)
            return

        # No existing class is close enough: open a new one holding this object.
        new_id = self.num_cls
        self.class_stat[new_id] = {self.num_objs}
        self.cls_obj_map[new_id, self.num_objs] = 1
        self.num_cls = new_id + 1

    def normalize(self, indices=None):
        """L2-normalise prototype rows in place and abort on inf/NaN.

        Args:
            indices: rows of ``self.atoms`` to normalise. ``None`` or an empty
                list/tuple/tensor/array selects the "no indices" path. Tensor and
                ndarray indices are tested by element count, not truthiness.
                ``not tensor`` raises for multi-element tensors and is ``True`` for
                a single-element tensor holding 0.

        Behaviour:
            * No indices: ``self.cur_atoms`` is replaced by its row-normalised
              version, then ``self.atoms`` is checked. On inf the process exits.
              On NaN, ``self.atoms`` is saved to two ``.pt`` files under
              ``/store/useless/`` before exiting.
            * With indices: the selected rows of ``self.atoms`` are normalised in
              place and checked; the process exits on inf/NaN.
            The process is terminated with ``sys.exit()`` after printing a message.
        """
        if isinstance(indices, torch.Tensor):
            no_indices = indices.numel() == 0
        elif isinstance(indices, np.ndarray):
            no_indices = indices.size == 0
        else:
            no_indices = not indices  # None, [], () -> no explicit selection

        if no_indices:  # user did not assign which prototypes to update
            save1 = self.atoms  # alias of self.atoms, not a copy
            # NOTE(review): this normalises self.cur_atoms but validates self.atoms,
            # which this branch never modifies — confirm which tensor is intended.
            self.cur_atoms = F.normalize(self.cur_atoms, p=2, dim=1)
            if torch.sum(torch.isinf(self.atoms)) > 0:
                print('detect inf in 17')
                sys.exit()
            if torch.sum(torch.isnan(self.atoms)) > 0:
                print('detect nan in 17')
                torch.save(save1, '/store/useless/save1.pt')
                torch.save(self.atoms, '/store/useless/save2.pt')
                sys.exit()
        else:
            # Only normalise the selected prototypes, for efficiency.
            self.atoms[indices] = F.normalize(self.atoms[indices], p=2, dim=1)
            if torch.sum(torch.isinf(self.atoms[indices])) > 0:
                print('detect inf in 18')
                sys.exit()
            if torch.sum(torch.isnan(self.atoms[indices])) > 0:
                print('detect nan in 18')
                sys.exit()

def data_preprocess():
    # this code preprocess the data provided by graphSAGE
    # return a feature map in which the one-hot indices of 60000 dimensions are concatenated with features of 500 dimensions
    ppi_G = json.load(open('/store/ppi_graphSAGE/ppi/ppi-G.json'))
    ppi_id_map = json.load(open('/store/ppi_graphSAGE/ppi/ppi-id_map.json'))
    ppi_feats = np.load('/store/ppi_graphSAGE/ppi/ppi-feats.npy')
    links = ppi_G['links']

    # data preprocess
    edges = []
    for link in links:
        mark = 100
        current_id = link['source']
        if mark != current_id:
            edges.append([])
        mark = current_id
        edges[current_id].append(
            link['target'])  # output a list in which the i-th is a list containing neighbors of node i
    print('mark 7')
    ## convert dict of id into a list
    id_list = []
    for id in ppi_id_map:
        id_list.append(ppi_id_map[id])

    ## obtain the binary encoding of the id map
    #ids = binary_position_encoding(5, id_list) # a 2-d np array containing binary encoding of each id
    ids = onehot_encoding(60000, id_list)  # a 2-d np array containing binary encoding of each id

    ## concatenate ids with features
    id_features = np.concatenate([ids, ppi_feats], 1).astype(np.float32)
    np.save('/store/ppi_graphSAGE/ppi/ids', ids)
